import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import sys
import random
import numpy as np
import copy

import torch
import torch.optim as optim
import torch.nn.functional as F
from models.r2d2_2heads import OBLR2D2Agent, DIAL_OBLR2D2Agent
from utils.memory import Memory, LocalBuffer, OBLMemory, OBLLocalBuffer, MIOBLMemory, MIOBLLocalBuffer, DIALRandomMemory, DialLocalBuffer
from tensorboardX import SummaryWriter

from models.r2d2_config import initial_exploration, batch_size, dial_batch_size, update_target, log_interval, eval_argmax, eval_interval, device, replay_memory_capacity, lr, dial_lr, dial_iql_eps, sequence_length, local_mini_batch, dial_local_mini_batch, use_mi_loss
from utils.pbmaze_config import env_config, iql_env_config
from phone_booth_collab_maze import PBCMaze
from pbcmaze_belief_model import ReceiverBeliefModel, SenderBeliefModel
from collections import deque

# Output directory for the .npy result arrays; the trailing slash matters
# because file paths are built by plain string concatenation.
RESULT_PATH = "results/"
# Directory prefix for saved agent models (also concatenated, not joined).
MODEL_PATH = "trained_models/"
# Number of independent training runs; each run is seeded with its run index.
NUM_RUNS = 5

def set_seed(seed):
    """Seed every random number generator the script draws from.

    PyTorch, Python's ``random`` module and NumPy's global generator are
    seeded, in that order, with the same ``seed`` value.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

def evaluate(eval_env, a0_agent, a1_agent, mode = "iql"):
    """Run one greedy (argmax) evaluation episode with a fixed policy head.

    Each step, agent 0 acts first and its policy is also handed to
    ``eval_env.step``; agent 1 then acts. Both recurrent states start at zero.

    Args:
        eval_env: Environment to evaluate in; it is reset at the start.
        a0_agent: Agent 0.
        a1_agent: Agent 1.
        mode: Policy head forwarded to ``get_action`` (e.g. "iql" or "obl").

    Returns:
        Tuple ``(score, info)``: the summed rewards of both agents over the
        episode and the ``info`` returned by the last ``eval_env.step`` call.
    """
    score = 0
    done = False
    info = None
    eval_env.reset()
    a0_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
    a1_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
    with torch.no_grad():
        while not done:
            # Agent 0's turn. `.cpu()` keeps the NumPy conversion valid when
            # `device` is a GPU (it is a no-op for CPU tensors).
            a0_obs = torch.Tensor(eval_env.get_obs(0)).to(device)
            a0_policy, a0_action, a0_hidden = a0_agent.get_action(a0_obs, a0_hidden, argmax = True, mode = mode)
            a0_reward, done, info = eval_env.step(0, a0_action, a0_policy.squeeze().detach().cpu().numpy())
            # Agent 1's turn.
            # NOTE(review): `done` from agent 0's step is overwritten here, so
            # agent 1 always acts once agent 0 has acted (same as in main()).
            a1_obs = torch.Tensor(eval_env.get_obs(1)).to(device)
            _, a1_action, a1_hidden = a1_agent.get_action(a1_obs, a1_hidden, argmax = True, mode = mode)
            a1_reward, done, info = eval_env.step(1, a1_action)
            score += a0_reward + a1_reward

    return score, info

def evaluate_mixed_policy(eval_env, a0_agent, a1_agent, stage_2_training, thres = 0.5):
    """Run one greedy evaluation episode using the head for the current stage.

    Both agents act with ``argmax = True``. Before stage 2 the "obl" head is
    evaluated; once ``stage_2_training`` is set the "iql" head is used. The
    action sequences of both agents are printed after the episode.

    Args:
        eval_env: Environment to evaluate in; it is reset at the start.
        a0_agent: Agent 0 (acts first each step; its policy is passed to ``step``).
        a1_agent: Agent 1.
        stage_2_training: Whether stage-2 training has started (selects the head).
        thres: Unused. Kept for backward compatibility; it belonged to a
            disabled scheme that switched heads on an MI-reward threshold.

    Returns:
        Tuple ``(score, info)``: the summed rewards of both agents over the
        episode and the ``info`` returned by the last ``eval_env.step`` call.
    """
    mode = "iql" if stage_2_training else "obl"
    score = 0
    done = False
    info = None
    eval_env.reset()

    a0_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
    a1_hidden = (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))
    a0_actions = []
    a1_actions = []

    with torch.no_grad():
        while not done:
            # Agent 0's turn; `.cpu()` keeps the NumPy conversion valid on GPU.
            a0_obs = torch.Tensor(eval_env.get_obs(0)).to(device)
            a0_policy, a0_action, a0_hidden = a0_agent.get_action(a0_obs, a0_hidden, argmax = True, mode = mode)
            a0_reward, done, info = eval_env.step(0, a0_action, a0_policy.squeeze().detach().cpu().numpy())
            # Agent 1's turn (overwrites `done`, matching the training loop).
            a1_obs = torch.Tensor(eval_env.get_obs(1)).to(device)
            _, a1_action, a1_hidden = a1_agent.get_action(a1_obs, a1_hidden, argmax = True, mode = mode)
            a1_reward, done, info = eval_env.step(1, a1_action)

            a0_actions.append(a0_action)
            a1_actions.append(a1_action)
            score += a0_reward + a1_reward

    print("a0 actions: {}".format(a0_actions))
    print("a1 actions: {}".format(a1_actions))
    return score, info

def evaluate_policy(a0_agent, a1_agent, mode = "iql"):
    """Print both agents' greedy policies once they stand in their phone booths.

    A fresh ``PBCMaze`` built from ``iql_env_config`` is walked forward with
    fixed actions (agent 0 takes action 1, agent 1 takes action 0) until agent 0
    is at ``booth_loc`` and agent 1 is at ``receiver_booth_loc``. Both recurrent
    states are rolled forward along the way. The goal is then forced to 2
    ("UP") and 3 ("DOWN"), and each agent's policy for that goal is printed.

    NOTE(review): the walk loop has no step limit; it relies on the fixed
    actions eventually reaching both booths.
    """
    def zero_state():
        return (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))

    with torch.no_grad():
        maze = PBCMaze(env_args=iql_env_config)
        maze.reset()
        agents = (a0_agent, a1_agent)
        states = [zero_state(), zero_state()]

        # Walk to the booths, only threading the hidden states through.
        while maze.agent0_loc[0] != maze.booth_loc or maze.agent1_loc[0] != maze.receiver_booth_loc:
            for idx, agent in enumerate(agents):
                observation = torch.Tensor(maze.get_obs(idx)).to(device)
                _, _, states[idx] = agent.get_action(observation, states[idx], argmax = True, mode = mode)
            maze.step(0, 1)
            maze.step(1, 0)

        # Probe each goal from the same (booth) hidden states.
        for goal, label in ((2, "UP"), (3, "DOWN")):
            maze.goal = goal
            observations = [torch.Tensor(maze.get_obs(idx)).to(device) for idx in range(2)]
            policies = [
                agent.get_action(observation, state, argmax = True, mode = mode)[0]
                for agent, observation, state in zip(agents, observations, states)
            ]
            print("Goal: " + label)
            for policy in policies:
                print(policy)

def _zero_hidden():
    """Return a fresh all-zero (h, c) recurrent state pair, each (1, 1, 16), on ``device``."""
    return (torch.Tensor().new_zeros(1, 1, 16).to(device), torch.Tensor().new_zeros(1, 1, 16).to(device))


def _result_suffix():
    """Return the filename suffix tagging results/models with the active options.

    MI shaping / MI loss take precedence over the intermediate reward tag
    ("_ir"); "_argmax" is appended whenever ``eval_argmax`` is enabled.
    """
    if env_config['use_mi_shaping'] or use_mi_loss:
        tag = ("_mi_log2" if env_config['use_mi_shaping'] else "") + ("_mi_loss" if use_mi_loss else "")
    elif env_config['use_intermediate_reward']:
        tag = "_ir"
    else:
        tag = ""
    return tag + ("_argmax" if eval_argmax else "")


def main():
    """Train the sender/receiver pair for ``NUM_RUNS`` seeded runs and save results.

    Each run seeds all RNGs with its run index. It builds a training ``PBCMaze``
    and an evaluation copy (MI shaping and intermediate reward disabled), then
    trains in two stages:

    * Stage 1: both agents act with the "obl" head, perform OBL sampling and
      train with ``train_model(obl = True)``.
    * After the episode where ``e + 1 >= num_episodes_for_mi_training``
      (i.e. after the first episode with the value 0 used here), stage 2
      starts. MI training is turned off in the env, ``obl_iql_weight_sync()``
      is called on every online/target net, the "backbone" and "obl" layers
      are frozen, and the optimizers are reset.
    * Stage 2: transitions go to the IQL buffers, agent 1 trains its IQL
      model, and agent 0 is trained with ``train_model_dial``. The DIAL
      buffer is presumably only filled in this stage, since
      ``dial_push_to_local_buffer`` is only called here.

    Per-run curves are saved to ``RESULT_PATH`` as .npy arrays. Only the
    agents of the last run are saved to ``MODEL_PATH``.
    """
    sender_time_to_booth_result = []
    receiver_time_to_booth_result = []
    reward_result = []
    eval_reward_result = []
    running_reward_result = []
    running_eval_reward_result = []
    for run_idx in range(NUM_RUNS):
        print("Run: " + str(run_idx + 1))
        # Set seed
        set_seed(run_idx)

        # Env
        num_episodes = 25000
        num_episodes_for_mi_training = 0
        stage_2_training = False

        env = PBCMaze(env_args=env_config)
        env.reset()
        eval_env = PBCMaze(env_args=env_config)
        eval_env.reset()
        eval_env.load_env_config(env.save_env_config())
        eval_env.use_mi_shaping = False
        eval_env.use_intermediate_reward = False

        # Agent 0 obs: ((channel, width, height), goal feature)
        # Agent 1 obs: ((channel, width, height), communication token)
        a0_input_shape = env.get_obs_size(0)
        a1_input_shape = env.get_obs_size(1)
        a0_num_actions = 7
        a1_num_actions = 5

        # Uniform initial policies for the belief models.
        receiver_pi_0 = [0.2, 0.2, 0.2, 0.2, 0.2]
        sender_pi_0 = [1/7, 1/7, 1/7, 1/7, 1/7, 1/7, 1/7]
        rb_model = ReceiverBeliefModel(receiver_pi_0, env)
        sb_model = SenderBeliefModel(sender_pi_0, env)
        # The sender's OBL replay memory type depends on whether the MI loss is used.
        if use_mi_loss:
            a0_agent = DIAL_OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), MIOBLMemory(replay_memory_capacity), MIOBLLocalBuffer(), DIALRandomMemory(replay_memory_capacity), DialLocalBuffer(), lr, batch_size, dial_batch_size, device, 0, rb_model, use_mi_loss, sigma = 2.0)
        else:
            a0_agent = DIAL_OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), DIALRandomMemory(replay_memory_capacity), DialLocalBuffer(), lr, batch_size, dial_batch_size, device, 0, rb_model, use_mi_loss, sigma = 2.0)
        a1_agent = OBLR2D2Agent(a1_input_shape, a1_num_actions, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 1, sb_model)

        writer = SummaryWriter('logs')

        running_score = 0
        running_eval_score = 0
        steps = 0
        loss = 0
        per_run_sender_time_to_booth_list = []
        per_run_receiver_time_to_booth_list = []
        per_run_reward = []
        per_run_eval_reward = []
        per_run_running_reward = []
        per_run_running_eval_reward = []
        for e in range(num_episodes):
            done = False

            score = 0
            a0_reward = None
            # None marks "agent 1 has not acted yet this episode".
            a1_reward = None
            env.reset()

            a0_hidden = _zero_hidden()
            a1_hidden = _zero_hidden()
            a1_next_hidden = None

            # DIAL window (stage 2 only): [agent 0 transition, agent 1 transition,
            # both-in-booth flag after that agent 1 step, next agent 0 transition,
            # next agent 1 transition]. It is flushed at length 5 (or on done) and
            # its last two entries start the next window.
            dial_transitions = []
            while not done:
                steps += 1

                if len(dial_transitions) == 2:
                    dial_transitions.append(both_in_booth_flag)

                a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                # get action from discovery policy
                if stage_2_training:
                    a0_policy, a0_action, a0_next_hidden = a0_agent.get_action(a0_obs, a0_hidden, mode = "obl", eps = dial_iql_eps)
                else:
                    a0_policy, a0_action, a0_next_hidden = a0_agent.get_action(a0_obs, a0_hidden, mode = "obl")
                # `.cpu()` keeps the NumPy conversion valid when `device` is a GPU.
                a0_policy_np = a0_policy.squeeze().detach().cpu().numpy()

                # OBL Sampling - Stage 1
                a0_curr_env_config = env.save_env_config()
                if not stage_2_training:
                    a0_agent.obl_sampling(a0_hidden, a0_next_hidden, a0_policy_np, a0_action, a0_curr_env_config, a1_agent, a1_next_hidden if a1_next_hidden is not None else a1_hidden)
                if len(a0_agent.local_buffer.memory) == local_mini_batch:
                    a0_agent.push_to_memory()

                # Agent 0's turn
                a0_reward, done, info = env.step(0, a0_action, a0_policy_np)
                both_in_booth_flag = env.sender_in_booth and env.receiver_in_booth
                mask = 0 if done else 1

                # Stage 2: store agent 1's previous transition; its reward spans
                # agent 1's step plus this agent 0 step.
                if stage_2_training:
                    if a1_reward is not None:
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward + a0_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Agent 1's IQL learning
                    if steps > initial_exploration and len(a1_agent.iql_memory) > batch_size:
                        loss, _ = a1_agent.train_iql_model()

                        if steps % update_target == 0:
                            a1_agent.update_target_model()
                    dial_transitions.append([a0_obs, a0_hidden, a0_action, torch.Tensor(env.get_obs(0)).to(device), a0_next_hidden, a0_reward, mask])

                # Update after a0 has taken an action
                if a1_next_hidden is not None:
                    a1_hidden = a1_next_hidden

                # Agent 0's OBL learning
                if len(a0_agent.memory) > batch_size:
                    if not stage_2_training:
                        loss, _ = a0_agent.train_model(obl = True, use_mi_loss = use_mi_loss)
                    if steps % update_target == 0:
                        a0_agent.update_target_model()

                # DIAL training
                if len(a0_agent.dial_memory) > dial_batch_size:
                    a0_agent.train_model_dial(a1_agent, mode = "iql")

                # Update belief
                a1_agent.belief_model.update_belief(comm_token = env.comm_token)

                a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                # get action from discovery policy
                if stage_2_training:
                    a1_policy, a1_action, a1_next_hidden = a1_agent.get_action(a1_obs, a1_hidden, mode = "obl", eps = dial_iql_eps)
                else:
                    a1_policy, a1_action, a1_next_hidden = a1_agent.get_action(a1_obs, a1_hidden, mode = "obl")

                a1_curr_env_config = env.save_env_config()
                # OBL Sampling - Stage 1
                if not stage_2_training:
                    a1_agent.obl_sampling(a1_hidden, a1_next_hidden, a1_policy.squeeze().detach().cpu().numpy(), a1_action, a1_curr_env_config, a0_agent, a0_next_hidden)
                if len(a1_agent.local_buffer.memory) == local_mini_batch:
                    a1_agent.push_to_memory()

                assert both_in_booth_flag == (env.sender_in_booth and env.receiver_in_booth)

                # Agent 1's turn
                a1_reward, done, info = env.step(1, a1_action)
                both_in_booth_flag = env.sender_in_booth and env.receiver_in_booth

                # Stage 2: store agent 0's transition (and agent 1's if the
                # episode ended), then update the DIAL window.
                if stage_2_training:
                    mask = 0 if done else 1
                    next_a0_obs = torch.Tensor(env.get_obs(0)).to(device)
                    a0_agent.iql_buffer.push(a0_obs, next_a0_obs, a0_action, a0_reward + a1_reward, mask, a0_hidden)
                    if len(a0_agent.iql_buffer.memory) == local_mini_batch:
                        a0_agent.push_to_iql_memory()

                    if done:
                        # Need to add to a1's buffer
                        next_a1_obs = torch.Tensor(env.get_obs(1)).to(device)
                        a1_agent.iql_buffer.push(a1_obs, next_a1_obs, a1_action, a1_reward, mask, a1_hidden)
                        if len(a1_agent.iql_buffer.memory) == local_mini_batch:
                            a1_agent.push_to_iql_memory()

                    # Agent 0's IQL learning.
                    # NOTE: a0_agent.train_iql_model() is intentionally disabled;
                    # only the periodic target-network update runs.
                    if len(a0_agent.iql_memory) > batch_size:
                        if steps % update_target == 0:
                            a0_agent.update_target_model()

                    dial_transitions.append([a1_obs, a1_hidden, a1_action, torch.Tensor(env.get_obs(1)).to(device), a1_next_hidden, a1_reward, mask])

                    if len(dial_transitions) % 5 == 0 or done:
                        # compute target to add to buffer
                        a0_agent.dial_push_to_local_buffer(dial_transitions, a1_agent)
                        dial_transitions = dial_transitions[3:]

                a0_hidden = a0_next_hidden

                if len(a0_agent.dial_buffer.memory) == dial_local_mini_batch:
                    a0_agent.push_to_dial_memory(a1_agent)

                # Agent 1's OBL learning
                if steps > initial_exploration and len(a1_agent.memory) > batch_size:
                    if not stage_2_training:
                        # Only train obl in stage 1
                        loss, _ = a1_agent.train_model(obl = True)
                    if steps % update_target == 0:
                        a1_agent.update_target_model()

                # Update belief
                a0_agent.belief_model.update_belief()

                score += a0_reward + a1_reward
                assert both_in_booth_flag == (env.sender_in_booth and env.receiver_in_booth)

            running_score = 0.99 * running_score + 0.01 * score
            # Steps to phone booth
            if eval_argmax:
                if e % eval_interval == 0:
                    eval_score, info = evaluate_mixed_policy(eval_env, a0_agent, a1_agent, stage_2_training, 0.5)
                    evaluate_policy(a0_agent, a1_agent, mode = "iql" if stage_2_training else "obl")
                    running_eval_score = 0.99 * running_eval_score + 0.01 * eval_score
                    sender_time_to_pb = info["sender_time_to_booth"]
                    receiver_time_to_pb = info["receiver_time_to_booth"]
                    per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                    per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                    per_run_eval_reward.append(eval_score)
                    per_run_running_eval_reward.append(running_eval_score)
                    print('Run {} | {} episode |score: {:.2f} | eval score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f}'.format(
                        run_idx + 1, e, running_score, running_eval_score, eval_score, sender_time_to_pb, receiver_time_to_pb))
                    sys.stdout.flush()
            else:
                sender_time_to_pb = info["sender_time_to_booth"]
                receiver_time_to_pb = info["receiver_time_to_booth"]
                per_run_sender_time_to_booth_list.append(sender_time_to_pb)
                per_run_receiver_time_to_booth_list.append(receiver_time_to_pb)
                per_run_reward.append(score)
                per_run_running_reward.append(running_score)
                if e % log_interval == 0:
                    print('Run {} | {} episode | score: {:.2f} | reward sum: {:.2f} | SenderToPB: {:.2f} | ReceiverToPB: {:.2f}'.format(
                        run_idx + 1, e, running_score, score, sender_time_to_pb, receiver_time_to_pb))
                    writer.add_scalar('log/score', float(running_score), e)
                    writer.add_scalar('log/loss', float(loss), e)
                    sys.stdout.flush()

            # Reset belief
            a0_agent.belief_model.reset_belief()
            a1_agent.belief_model.reset_belief()

            # Switch to stage 2 (turn off MI training) exactly once.
            if (e + 1) >= num_episodes_for_mi_training and not stage_2_training:
                env.turn_off_mi_training()
                stage_2_training = True
                # Greedy evaluation with the stage-2 head; prints both action sequences.
                evaluate_mixed_policy(eval_env, a0_agent, a1_agent, stage_2_training, 0.5)
                # Weight sync
                a0_agent.online_net.obl_iql_weight_sync()
                a0_agent.target_net.obl_iql_weight_sync()
                a1_agent.online_net.obl_iql_weight_sync()
                a1_agent.target_net.obl_iql_weight_sync()
                # Freeze discovery policy
                a0_agent.online_net.set_layers_training_mode(["backbone", "obl"], False)
                a0_agent.target_net.set_layers_training_mode(["backbone", "obl"], False)
                a1_agent.online_net.set_layers_training_mode(["backbone", "obl"], False)
                a1_agent.target_net.set_layers_training_mode(["backbone", "obl"], False)
                # Reset optimizers
                a0_agent.reset_optimizer()
                a1_agent.reset_optimizer()
                print("Stage 2 training starts, freeze done")
                sys.stdout.flush()

        # Flush and release this run's event file instead of leaking one writer per run.
        writer.close()

        sender_time_to_booth_result.append(per_run_sender_time_to_booth_list)
        receiver_time_to_booth_result.append(per_run_receiver_time_to_booth_list)
        if eval_argmax:
            eval_reward_result.append(per_run_eval_reward)
            running_eval_reward_result.append(per_run_running_eval_reward)
        else:
            reward_result.append(per_run_reward)
            running_reward_result.append(per_run_running_reward)

    # Save results
    os.makedirs(RESULT_PATH, exist_ok=True)
    os.makedirs(MODEL_PATH, exist_ok=True)

    suffix = _result_suffix()
    sender_result_filename = "obl_dial_sender_time_to_pb" + suffix + ".npy"
    receiver_result_filename = "obl_dial_receiver_time_to_pb" + suffix + ".npy"
    reward_result_filename = "obl_dial_reward" + suffix + ".npy"
    running_reward_result_filename = "obl_dial_running_reward" + suffix + ".npy"
    # NOTE(review): model names contain a space before the suffix
    # (e.g. "obl_dial_sender_model _argmax"). Kept as-is so existing
    # checkpoints and loaders keep matching; confirm before changing.
    sender_model_path = MODEL_PATH + "obl_dial_sender_model " + suffix
    receiver_model_path = MODEL_PATH + "obl_dial_receiver_model " + suffix

    if eval_argmax:
        np.save(RESULT_PATH + reward_result_filename, np.array(eval_reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_eval_reward_result))
    else:
        np.save(RESULT_PATH + reward_result_filename, np.array(reward_result))
        np.save(RESULT_PATH + running_reward_result_filename, np.array(running_reward_result))

    np.save(RESULT_PATH + sender_result_filename, np.array(sender_time_to_booth_result))
    np.save(RESULT_PATH + receiver_result_filename, np.array(receiver_time_to_booth_result))

    # Save model (only the agents of the last run)
    a0_agent.save_model(sender_model_path)
    a1_agent.save_model(receiver_model_path)

# Script entry point: run all training runs when executed directly.
if __name__=="__main__":
    main()
